def solve(matrix):
    """Zero out every row and column that contains a 0, in place.

    Two passes are used:

    1. Record the indices of every row and column holding a zero.
    2. Write 0 into each cell whose row or column was recorded.

    Recording everything before writing keeps the zeros written in pass 2
    from being mistaken for original zeros. Otherwise they would cascade
    into further rows and columns.

    Args:
        matrix: A list of row lists. Each row is scanned over its own
            length, so ragged rows are handled without IndexError.

    Returns:
        The same ``matrix`` object, mutated in place. An empty matrix is
        returned unchanged.
    """
    if not matrix:
        # Nothing to scan; avoids indexing matrix[0] on an empty list.
        return matrix

    zero_rows = set()
    zero_cols = set()
    for i, row in enumerate(matrix):
        for j, value in enumerate(row):
            if value == 0:
                zero_rows.add(i)
                zero_cols.add(j)

    for i, row in enumerate(matrix):
        for j in range(len(row)):
            if i in zero_rows or j in zero_cols:
                row[j] = 0

    return matrix


if __name__ == "__main__":
    # Demo: the single 0 in the centre clears the middle row and column.
    matrix = [[1, 1, 1], [1, 0, 1], [1, 1, 1]]
    # solve() mutates its argument in place and also returns it.
    matrix = solve(matrix)
    for row in matrix:
        print(row)
